JuliaCon 2022
Patrick Altmeyer
CounterfactualExplanations.jl in Julia (and beyond!) 📦
A recipe for disaster …
“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”
— Cathy O’Neil in Weapons of Math Destruction, 2016
Figure 2: Cathy O’Neil. Source: Cathy O’Neil
Data
Probabilistic Models
Counterfactual Reasoning
Model objective: maximize \(p(\mathcal{D}|\theta)\) where \(\mathcal{D}=\{(x_i,y_i)\}_{i=1}^n\) (supervised)
[…] deep neural networks are typically very underspecified by the available data, and […] parameters [therefore] correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)
Data
Probabilistic Models
Counterfactual Reasoning
Even though […] interpretability is of great importance and should be pursued, explanations can, in principle, be offered without opening the “black box”. (Wachter, Mittelstadt, and Russell 2017)
\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \ \ \ \mbox{s.t.} \ \ \ M(x^\prime) = t \qquad(1)\]
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(2)\]
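The relaxed objective in Equation 2 can be minimized by plain gradient descent. A minimal sketch of that idea (in Python rather than the package's Julia API; the fixed 2-D logistic model, penalty weight and step size are all illustrative assumptions):

```python
import math

# Illustrative black-box: a fixed 2-D logistic classifier (an assumption, not the package's model).
w, b = [1.0, 1.0], 0.0

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(x):
    return sigmoid(w[0] * x[0] + w[1] * x[1] + b)

def counterfactual(x_factual, target=1.0, lam=0.1, step=0.1, iters=500):
    """Gradient descent on loss(M(x'), t) + lam * ||x' - x||^2, cf. Equation 2."""
    x = list(x_factual)
    for _ in range(iters):
        p = predict(x)
        dz = p - target  # d(binary cross-entropy)/d(logit) for a sigmoid output
        for j in range(len(x)):
            grad = dz * w[j] + 2 * lam * (x[j] - x_factual[j])
            x[j] -= step * grad
    return x

x0 = [-2.0, -2.0]
x_cf = counterfactual(x0)
print(predict(x0), predict(x_cf))  # the counterfactual crosses the decision boundary
```

The penalty term keeps the counterfactual close to the factual, so the search stops just past the decision boundary rather than deep inside the target class.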
We have fitted some black-box classifier to divide cats and dogs. One 🐱 is friends with a lot of cool 🐶 and wants to remain part of that group. The counterfactual path below shows her how to fool the classifier:
Yes and no!
While the two are methodologically very similar, adversarial examples are meant to go undetected, while counterfactual explanations (CEs) ought to be meaningful.
Effective counterfactuals should meet certain criteria ✅
closeness: the average distance between factual and counterfactual features should be small (Wachter, Mittelstadt, and Russell 2017)
actionability: the proposed feature perturbation should actually be actionable (Ustun, Spangher, and Liu 2019; Poyiadzi et al. 2020)
plausibility: the counterfactual explanation should be plausible to a human (Joshi et al. 2019)
unambiguity: a human should have no trouble assigning a label to the counterfactual (Schut et al. 2021)
sparsity: the counterfactual explanation should involve as few individual feature changes as possible (Schut et al. 2021)
robustness: the counterfactual explanation should be robust to domain and model shifts (Upadhyay, Joshi, and Lakkaraju 2021)
diversity: ideally multiple diverse counterfactual explanations should be provided (Mothilal, Sharma, and Tan 2020)
causality: counterfactual explanations should reflect the structural causal model underlying the data generating process (Karimi et al. 2020; Karimi, Schölkopf, and Valera 2021)
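Some of these criteria can be checked numerically once a counterfactual has been found. A small sketch with two common metric choices — closeness as mean absolute feature change and sparsity as the count of changed features (these definitions are illustrative, not the package's exact implementations):

```python
def closeness(x, x_cf):
    """Average absolute feature change between factual and counterfactual (lower is better)."""
    return sum(abs(a - b) for a, b in zip(x, x_cf)) / len(x)

def sparsity(x, x_cf, tol=1e-6):
    """Number of features that actually changed (lower is better)."""
    return sum(abs(a - b) > tol for a, b in zip(x, x_cf))

x, x_cf = [1.0, 2.0, 3.0], [1.0, 2.5, 3.0]
print(closeness(x, x_cf), sparsity(x, x_cf))  # only the second feature moved
```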
NO!
Well, maybe …
There is nonetheless an intriguing link between the two domains:
When people say that counterfactuals should look realistic or plausible, they really mean that counterfactuals should be generated by the same Data Generating Process (DGP) as the factuals:
\[ x^\prime \sim p(x) \]
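A simple heuristic proxy for this requirement is to score a candidate by its distance to the nearest training observation, treating points far from the data as implausible (a stand-in for density under p(x), not the package's approach; the data and candidates below are made up):

```python
def implausibility(x_cf, X_train):
    """Euclidean distance to the nearest training observation (lower = more plausible)."""
    return min(
        sum((a - b) ** 2 for a, b in zip(x_cf, x)) ** 0.5
        for x in X_train
    )

X_train = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
print(implausibility([1.1, 0.9], X_train))    # close to the data
print(implausibility([10.0, -10.0], X_train))  # far off the data manifold
```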
CounterfactualExplanations.jl 📦
Julia has an edge with respect to Trustworthy AI: it’s open-source, uniquely transparent and interoperable 🔴🟢🟣
Modular, composable, scalable!
For more detail, check out the repo directly
# Data:
using CounterfactualExplanations
using CounterfactualExplanations.Data
using Random
Random.seed!(1234)
N = 25
xs, ys = Data.toy_data_linear(N)
X = hcat(xs...)
counterfactual_data = CounterfactualData(X, ys')
# Model
using CounterfactualExplanations.Models: LogisticModel, probs
# Logit model:
w = [1.0 1.0] # true coefficients
b = 0
M = LogisticModel(w, [b])
# Randomly selected factual:
Random.seed!(123)
x = select_factual(counterfactual_data, rand(1:size(X, 2)))
y = round(probs(M, x)[1])
target = ifelse(y == 1.0, 0.0, 1.0) # opposite label as target
# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)
# Model:
using LinearAlgebra
Σ = Symmetric(reshape(randn(9),3,3).*0.01 + UniformScaling(1)) # MAP covariance matrix
μ = hcat(b, w)
M = CounterfactualExplanations.Models.BayesianLogisticModel(μ, Σ)
# Counterfactual search:
generator = GreedyGenerator(;δ=0.1, n=25)
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)
But things can go wrong …
using Flux, RCall
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend
# Step 1)
struct TorchNetwork <: Models.AbstractFittedModel
model::Any
end
# Step 2)
function logits(M::TorchNetwork, X::AbstractArray)
nn = M.model
ŷ = rcopy(R"as_array($nn(torch_tensor(t($X))))")
ŷ = isa(ŷ, AbstractArray) ? ŷ : [ŷ]
return ŷ'
end
probs(M::TorchNetwork, X::AbstractArray) = σ.(logits(M, X))
M = TorchNetwork(R"model")

import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra
# Counterfactual loss:
function ∂ℓ(generator::AbstractGradientBasedGenerator, counterfactual_state::CounterfactualState)
M = counterfactual_state.M
nn = M.model
x′ = counterfactual_state.x′
t = counterfactual_state.target_encoded
R"""
x <- torch_tensor($x′, requires_grad=TRUE)
output <- $nn(x)
obj_loss <- nnf_binary_cross_entropy_with_logits(output,$t)
obj_loss$backward()
"""
grad = rcopy(R"as_array(x$grad)")
return grad
end

# Abstract subtype:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
# Constructor:
struct DropoutGenerator <: AbstractDropoutGenerator
loss::Symbol # loss function
complexity::Function # complexity function
mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
λ::AbstractFloat # strength of penalty
ϵ::AbstractFloat # step size
τ::AbstractFloat # tolerance for convergence
p_dropout::AbstractFloat # dropout rate
end
# Instantiate:
using LinearAlgebra
generator = DropoutGenerator(
:logitbinarycrossentropy,
norm,
nothing,
0.1,
0.1,
1e-5,
0.5
)

import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::CounterfactualState)
𝐠ₜ = ∇(generator, counterfactual_state) # gradient
# Dropout:
set_to_zero = sample(1:length(𝐠ₜ), Int(round(generator.p_dropout * length(𝐠ₜ))), replace=false)
𝐠ₜ[set_to_zero] .= 0
Δx′ = - (generator.ϵ .* 𝐠ₜ) # gradient step
return Δx′
end
This approach generalizes to deep learning models (Flux, torch, tensorflow) and other differentiable models.
What happens once algorithmic recourse (AR) has actually been implemented? 👀
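The dropout step in `generate_perturbations` above is easy to isolate: zero out a random fraction of the gradient entries before each descent step. A plain-Python sketch of just that step (function name and inputs are illustrative):

```python
import random

def dropout_gradient(grad, p_dropout=0.5, rng=random):
    """Set a random p_dropout fraction of gradient entries to zero before the step."""
    g = list(grad)
    k = int(round(p_dropout * len(g)))
    for i in rng.sample(range(len(g)), k):
        g[i] = 0.0
    return g

random.seed(1)
g = dropout_gradient([0.4, -0.2, 0.1, 0.3], p_dropout=0.5)
print(g)  # two of the four entries are zeroed
```

Dropping random gradient coordinates encourages sparser perturbations, since each step only moves a subset of features.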
JuliaCon 2022 - Explaining Black-Box Models through Counterfactuals